{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import polars as pl\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "# CHANGE\n",
    "PATH = '' # root data folder\n",
    "PATH_OUTPUT = PATH + '' # curated data\n",
    "PATH_ACCURACY = PATH + '' # accuracy evaluation data (manually verified)\n",
    "PATH_CHARTS = PATH + '' # charts\n",
    "PATH_AGGREGATE = '' # aggregate data (provided in this package)\n",
    "\n",
    "pl.Config.set_fmt_str_lengths(100);\n",
    "\n",
    "MATCHES = 'df_matches_0.csv' # name of the manually evaluated matches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_matches_checked = (\n",
    "    pl.read_csv(PATH_ACCURACY + MATCHES)\n",
    "    .filter(pl.all(pl.col('^correct_match_.*$').is_not_null()))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute accuracy\n",
    "\n",
    "def separate_df_by_cutoff(df, cutoff, location):\n",
    "    return (\n",
    "        df.filter(pl.col(f'score_{location}') < cutoff),\n",
    "        df.filter(pl.col(f'score_{location}') >= cutoff)\n",
    "    )\n",
    "\n",
    "def count_pos_matches(df, location):\n",
    "    counts = {}\n",
    "\n",
    "    counts['true_pos'] = sum(df[f'correct_match_{location}'])\n",
    "    counts['false_pos'] = len(df) - counts['true_pos']\n",
    "\n",
    "    weighted_total = sum(df['job_count'])\n",
    "    counts['weighted_true_pos'] = sum(\n",
    "        (df[f'correct_match_{location}'] * df['job_count'])\n",
    "        )\n",
    "    counts['weighted_false_pos'] = weighted_total - counts['weighted_true_pos']\n",
    "\n",
    "    return counts\n",
    "\n",
    "def count_neg_matches(df):\n",
    "    counts = {}\n",
    "\n",
    "    counts['false_neg'] = sum(\n",
    "        df.select(pl.any(pl.col('^correct_match_.*$'))).to_series()\n",
    "        )\n",
    "    counts['true_neg'] = len(df) - counts['false_neg']\n",
    "\n",
    "    counts['weighted_false_neg'] = sum(\n",
    "        df.select(pl.any(pl.col('^correct_match_.*$'))).to_series() *\n",
    "        df['job_count']\n",
    "        )\n",
    "    counts['weighted_true_neg'] = sum(df['job_count']) - counts['weighted_false_neg']\n",
    "\n",
    "    return counts\n",
    "\n",
    "def sum_up_results(location_counts):\n",
    "    counts = defaultdict(int)\n",
    "    for location in location_counts:\n",
    "        for k in location_counts[location]:\n",
    "            counts[k] += location_counts[location][k]\n",
    "    counts['total'] = sum([counts[k] for k in counts if 'weighted' not in k])\n",
    "    counts['weighted_total'] = sum([counts[k] for k in counts if 'weighted' in k])\n",
    "    return counts\n",
    "\n",
    "def compute_true_pos_rate(count, weighted=False):\n",
    "    if weighted:\n",
    "        return count['weighted_true_pos'] / (count['weighted_true_pos'] + count['weighted_false_neg'])\n",
    "    else:\n",
    "        return count['true_pos'] / (count['true_pos'] + count['false_neg'])\n",
    "\n",
    "def compute_false_pos_rate(count, weighted=False):\n",
    "    if weighted:\n",
    "        return count['weighted_false_pos'] / (count['weighted_false_pos'] + count['weighted_true_neg'])\n",
    "    else:\n",
    "        return count['false_pos'] / (count['false_pos'] + count['true_neg'])\n",
    "\n",
    "\n",
    "def compute_accuracy(count, weighted=False):\n",
    "    if weighted:\n",
    "        return (count['weighted_true_pos'] + count['weighted_true_neg']) / count['weighted_total']\n",
    "    else:\n",
    "        return (count['true_pos'] + count['true_neg']) / count['total']\n",
    "\n",
    "def compute_perc_matched(count, weighted=False):\n",
    "    if weighted:\n",
    "        return (count['weighted_true_pos'] + count['weighted_false_pos']) / count['weighted_total']\n",
    "    else:\n",
    "        return (count['true_pos'] + count['false_pos']) / count['total']\n",
    "\n",
    "def compute_precision(count, weighted=False):\n",
    "    try:\n",
    "        if weighted:\n",
    "            return count['weighted_true_pos'] / (count['weighted_true_pos'] + count['weighted_false_pos'])\n",
    "        else:\n",
    "            return count['true_pos'] / (count['true_pos'] + count['false_pos'])\n",
    "    except ZeroDivisionError:\n",
    "        return 0\n",
    "\n",
    "def compute_recall(count, weighted=False):\n",
    "    try:\n",
    "        if weighted:\n",
    "            return count['weighted_true_pos'] / (count['weighted_true_pos'] + count['weighted_false_neg'])\n",
    "        else:\n",
    "            return count['true_pos'] / (count['true_pos'] + count['false_neg'])\n",
    "    except ZeroDivisionError:\n",
    "        return 0\n",
    "\n",
    "def count_cases(df_matches_checked, cutoffs):\n",
    "    \"\"\"for given cutoffs at city, province, national level, \n",
    "    counts the number of true positives, false positives, true negatives, \n",
    "    and false negatives, weighted by job counts and unweighted\"\"\"\n",
    "\n",
    "    df_below_cut = df_matches_checked.clone()\n",
    "    location_counts = {}\n",
    "    cutoffs = dict(zip(['city', 'province', 'national'], cutoffs))\n",
    "    for location, cutoff in cutoffs.items():\n",
    "        df_below_cut, df_above_cut = separate_df_by_cutoff(df_below_cut, cutoff, location)\n",
    "        location_counts[location] = count_pos_matches(df_above_cut, location)\n",
    "\n",
    "    location_counts['others'] = count_neg_matches(df_below_cut)\n",
    "\n",
    "    return sum_up_results(location_counts)\n",
    "\n",
    "def compute_results(df_matches_checked, all_cutoffs):\n",
    "    results = defaultdict(list)\n",
    "    accuracy, weighted_accuracy, perc_matched, weighted_perc_matched = [], [], [], []\n",
    "    for cutoffs in all_cutoffs:\n",
    "        counts = count_cases(df_matches_checked, cutoffs)\n",
    "        results['cutoff'].append(cutoffs[0])\n",
    "        results['accuracy'].append(compute_accuracy(counts))\n",
    "        results['weighted_accuracy'].append(compute_accuracy(counts, weighted=True))\n",
    "        results['perc_matched'].append(compute_perc_matched(counts))\n",
    "        results['weighted_perc_matched'].append(compute_perc_matched(counts, weighted=True))\n",
    "        results['true_pos_rate'].append(compute_true_pos_rate(counts))\n",
    "        results['weighted_true_pos_rate'].append(compute_true_pos_rate(counts, weighted=True))\n",
    "        results['false_pos_rate'].append(compute_false_pos_rate(counts))\n",
    "        results['weighted_false_pos_rate'].append(compute_false_pos_rate(counts, weighted=True))\n",
    "        results['precision'].append(compute_precision(counts))\n",
    "        results['weighted_precision'].append(compute_precision(counts, weighted=True))\n",
    "        results['recall'].append(compute_recall(counts))\n",
    "        results['weighted_recall'].append(compute_recall(counts, weighted=True))\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_cutoffs = [[cutoff, cutoff, cutoff] for cutoff in np.arange(-1e-9, 1.00, 0.01)]\n",
    "results = compute_results(df_matches_checked, all_cutoffs)\n",
    "\n",
    "# Create a DataFrame for the results\n",
    "df_plot = pd.DataFrame(\n",
    "    {\n",
    "        'accuracy': results['accuracy'],\n",
    "        'weighted accuracy': results['weighted_accuracy'],\n",
    "        'percentage matched': results['perc_matched'],\n",
    "        'weighted percentage matched': results['weighted_perc_matched']\n",
    "    },\n",
    "    index=[cutoff[0] for cutoff in all_cutoffs],\n",
    ")\n",
    "\n",
    "df_plot.to_csv(PATH_AGGREGATE + 'accuracy_figure.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick cutoff s.t. weighted accuracy is closest to 0.9\n",
    "\n",
    "(\n",
    "    pl.DataFrame(\n",
    "        {\n",
    "            'weighted_accuracy': results['weighted_accuracy'],\n",
    "            'cutoff': results['cutoff']\n",
    "        }\n",
    "    )\n",
    "    .with_columns(((pl.col('weighted_accuracy') - 0.9).abs()).alias('distance_to_0.9'))\n",
    "    .sort('distance_to_0.9')\n",
    "    .head(4)\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12 ('env_indeed2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "26da711ab583a058e13fb43990c4acfd219f633b35c618763404a0fc5624b2b9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
